Skip to content

Conversation

@hfrick
Copy link
Member

@hfrick hfrick commented Apr 8, 2021

Closes #460

The predictions previously ignored the event_level, they respect it now.

library(tidyverse)
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
library(mlbench)
library(xgboost)
#> 
#> Attaching package: 'xgboost'
#> The following object is masked from 'package:dplyr':
#> 
#>     slice
library(reprex)


data("PimaIndiansDiabetes")
df <- PimaIndiansDiabetes %>%
  mutate(diabetes = fct_relevel(diabetes, 'pos'))


# xgboost
x <- as.matrix(df[,-ncol(df)])
y <- -as.numeric(df$diabetes)+2

xgbmat <- xgb.DMatrix(data = x, label = y)

set.seed(24)
xgb_model <- xgb.train(params = list(eta = 0.3, max_depth = 3, gamma = 0, 
                                     colsample_bytree = 1, min_child_weight = 1, subsample = 1, 
                                     objective = "binary:logistic"),  watchlist = list('train' = xgbmat),
                       verbose = 1, data = xgbmat, nrounds = 10, eval_metric = "auc", 
                       nthread = 1)
#> [1]  train-auc:0.834787 
#> [2]  train-auc:0.853881 
#> [3]  train-auc:0.873048 
#> [4]  train-auc:0.875000 
#> [5]  train-auc:0.882873 
#> [6]  train-auc:0.891716 
#> [7]  train-auc:0.896067 
#> [8]  train-auc:0.900414 
#> [9]  train-auc:0.904313 
#> [10] train-auc:0.906937

tidy_model <- boost_tree(trees = 10, 
                         tree_depth = 3) %>%
  set_engine('xgboost', 
             eval_metric = 'auc',
             event_level = "first",
             verbose = 1) %>%
  set_mode('classification')

set.seed(24)
tidy_model_fitted <- tidy_model %>% 
  fit(diabetes ~ . , data = df)
#> [1]  training-auc:0.834787 
#> [2]  training-auc:0.853881 
#> [3]  training-auc:0.873048 
#> [4]  training-auc:0.875000 
#> [5]  training-auc:0.882873 
#> [6]  training-auc:0.891716 
#> [7]  training-auc:0.896067 
#> [8]  training-auc:0.900414 
#> [9]  training-auc:0.904313 
#> [10] training-auc:0.906937


# xgboost predictions
predict(xgb_model, newdata = xgbmat) %>%
  as_tibble() %>%
  select(.pred_pos = value) %>%
  mutate(.pred_meg = 1 - .pred_pos)
#> # A tibble: 768 x 2
#>    .pred_pos .pred_meg
#>        <dbl>     <dbl>
#>  1    0.586      0.414
#>  2    0.106      0.894
#>  3    0.778      0.222
#>  4    0.0541     0.946
#>  5    0.613      0.387
#>  6    0.0963     0.904
#>  7    0.163      0.837
#>  8    0.370      0.630
#>  9    0.834      0.166
#> 10    0.344      0.656
#> # … with 758 more rows

# tidy model predicts are associated with the right class now
predict(tidy_model_fitted, new_data = df, type = 'prob')
#> # A tibble: 768 x 2
#>    .pred_pos .pred_neg
#>        <dbl>     <dbl>
#>  1    0.586      0.414
#>  2    0.106      0.894
#>  3    0.778      0.222
#>  4    0.0541     0.946
#>  5    0.613      0.387
#>  6    0.0963     0.904
#>  7    0.163      0.837
#>  8    0.370      0.630
#>  9    0.834      0.166
#> 10    0.344      0.656
#> # … with 758 more rows

xgb_predictions <- 
  predict(xgb_model, newdata = xgbmat) %>%
  as_tibble() %>%
  select(.pred_pos = value) %>%
  mutate(.pred_meg = 1 - .pred_pos)

tidy_predictions <- predict(tidy_model_fitted, new_data = df, type = 'prob')


# Correct prediction to class
roc_auc_vec(df$diabetes, xgb_predictions$.pred_pos)
#> [1] 0.9069366

# now correct prediction to class
roc_auc_vec(df$diabetes, tidy_predictions$.pred_pos)
#> [1] 0.9069366

Created on 2021-04-08 by the reprex package (v2.0.0)

@hfrick hfrick requested a review from DavisVaughan April 8, 2021 10:51
@hfrick hfrick requested a review from DavisVaughan April 9, 2021 12:42
@hfrick hfrick merged commit 228a6dc into master Apr 13, 2021
@hfrick hfrick deleted the pred-event-level branch April 13, 2021 17:28
@github-actions
Copy link
Contributor

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Apr 28, 2021
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

predict() associates prediction with wrong class when event_level = 'first'

3 participants